{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interval based time series classification in sktime\n",
    "\n",
    "Interval based approaches look at phase dependent intervals of the full series, calculating summary statistics from selected subseries to be used in classification.\n",
    "\n",
    "Currently 5 univariate interval based approaches are implemented in sktime: Time Series Forest (TSF) \\[1\\], the Random Interval Spectral Ensemble (RISE) \\[2\\], Supervised Time Series Forest (STSF) \\[3\\], the Canonical Interval Forest (CIF) \\[4\\] and the Diverse Representation Canonical Interval Forest (DrCIF). CIF and DrCIF also support multivariate series.\n",
    "\n",
    "In this notebook, we will demonstrate how to use these classifiers on the ItalyPowerDemand and BasicMotions datasets.\n",
    "\n",
    "#### References:\n",
    "\n",
    "\\[1\\] Deng, H., Runger, G., Tuv, E., & Vladimir, M. (2013). A time series forest for classification and feature extraction. Information Sciences, 239, 142-153.\n",
    "\n",
    "\\[2\\] Flynn, M., Large, J., & Bagnall, T. (2019). The contract random interval spectral ensemble (c-RISE): the effect of contracting a classifier on accuracy. In International Conference on Hybrid Artificial Intelligence Systems (pp. 381-392). Springer, Cham.\n",
    "\n",
    "\\[3\\] Cabello, N., Naghizade, E., Qi, J., & Kulik, L. (2020). Fast and Accurate Time Series Classification Through Supervised Interval Search. In IEEE International Conference on Data Mining.\n",
    "\n",
    "\\[4\\] Middlehurst, M., Large, J., & Bagnall, A. (2020). The Canonical Interval Forest (CIF) Classifier for Time Series Classification. arXiv preprint arXiv:2008.09172.\n",
    "\n",
    "\\[5\\] Lubba, C. H., Sethi, S. S., Knaute, P., Schultz, S. R., Fulcher, B. D., & Jones, N. S. (2019). catch22: CAnonical Time-series CHaracteristics. Data Mining and Knowledge Discovery, 33(6), 1821-1852."
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 1. Imports"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "from sklearn.pipeline import Pipeline\n",
    "\n",
    "from sktime.classification.interval_based import (\n",
    "    CanonicalIntervalForest,\n",
    "    DrCIF,\n",
    "    RandomIntervalSpectralEnsemble,\n",
    "    SupervisedTimeSeriesForest,\n",
    "    TimeSeriesForestClassifier,\n",
    ")\n",
    "from sktime.datasets import load_basic_motions, load_italy_power_demand\n",
    "from sktime.transformations.panel.compose import ColumnConcatenator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, y_train = load_italy_power_demand(split=\"train\", return_X_y=True)\n",
    "X_test, y_test = load_italy_power_demand(split=\"test\", return_X_y=True)\n",
    "X_test = X_test[:50]\n",
    "y_test = y_test[:50]\n",
    "\n",
    "print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)\n",
    "\n",
    "X_train_mv, y_train_mv = load_basic_motions(split=\"train\", return_X_y=True)\n",
    "X_test_mv, y_test_mv = load_basic_motions(split=\"test\", return_X_y=True)\n",
    "\n",
    "X_train_mv = X_train_mv[:50]\n",
    "y_train_mv = y_train_mv[:50]\n",
    "X_test_mv = X_test_mv[:50]\n",
    "y_test_mv = y_test_mv[:50]\n",
    "\n",
    "print(X_train_mv.shape, y_train_mv.shape, X_test_mv.shape, y_test_mv.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Time Series Forest (TSF)\n",
    "\n",
    "TSF is an ensemble of tree classifiers built on the summary statistics of randomly selected intervals.\n",
    "For each tree, sqrt(series_length) intervals are randomly selected.\n",
    "From each of these intervals, the mean, standard deviation and slope are extracted from each time series and concatenated into a feature vector.\n",
    "These new features are then used to build a tree, which is added to the ensemble."
   ]
  },
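  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The feature extraction described above can be sketched in plain Python. This is a minimal illustration rather than sktime's implementation: the helper names `interval_features` and `random_interval_features`, the minimum interval length of 3 and the use of the population standard deviation are all assumptions made for the example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import random\n",
    "import statistics\n",
    "\n",
    "\n",
    "def interval_features(series, start, end):\n",
    "    # Mean, standard deviation and slope of series[start:end].\n",
    "    window = series[start:end]\n",
    "    n = len(window)\n",
    "    mean = statistics.fmean(window)\n",
    "    std = statistics.pstdev(window)\n",
    "    # Least-squares slope of the values against their time index 0..n-1.\n",
    "    t_mean = (n - 1) / 2\n",
    "    denom = sum((t - t_mean) ** 2 for t in range(n))\n",
    "    slope = sum((t - t_mean) * (v - mean) for t, v in enumerate(window)) / denom\n",
    "    return mean, std, slope\n",
    "\n",
    "\n",
    "def random_interval_features(series, rng, min_length=3):\n",
    "    # Concatenate the features of sqrt(len(series)) random intervals.\n",
    "    n_intervals = int(math.sqrt(len(series)))\n",
    "    features = []\n",
    "    for _ in range(n_intervals):\n",
    "        start = rng.randrange(0, len(series) - min_length + 1)\n",
    "        end = rng.randrange(start + min_length, len(series) + 1)\n",
    "        features.extend(interval_features(series, start, end))\n",
    "    return features\n",
    "\n",
    "\n",
    "# Toy series: a straight line, so every extracted slope equals 1.\n",
    "series = [float(v) for v in range(24)]\n",
    "features = random_interval_features(series, random.Random(47))\n",
    "print(len(features))"
   ]
  },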
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tsf = TimeSeriesForestClassifier(n_estimators=50, random_state=47)\n",
    "tsf.fit(X_train, y_train)\n",
    "\n",
    "tsf_preds = tsf.predict(X_test)\n",
    "print(\"TSF Accuracy: \" + str(metrics.accuracy_score(y_test, tsf_preds)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "tsf = Pipeline(\n",
    "    [\n",
    "        (\"column_concatenator\", ColumnConcatenator()),\n",
    "        (\"classify\", TimeSeriesForestClassifier(n_estimators=50, random_state=47)),\n",
    "    ]\n",
    ")\n",
    "tsf.fit(X_train_mv, y_train_mv)\n",
    "\n",
    "tsf_preds = tsf.predict(X_test_mv)\n",
    "print(\"TSF Accuracy: \" + str(metrics.accuracy_score(y_test_mv, tsf_preds)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "temporal_feature_importance = tsf[\"classify\"].feature_importances_\n",
    "separators = range(0, tsf[\"classify\"].series_length, len(X_train_mv.iloc[0, 0]))\n",
    "\n",
    "ax = temporal_feature_importance.plot(figsize=(20, 10))\n",
    "for index, separator in enumerate(separators):\n",
    "    ax.vlines(\n",
    "        separator,\n",
    "        temporal_feature_importance.min().min(),\n",
    "        temporal_feature_importance.max().max(),\n",
    "        color=\"r\",\n",
    "        alpha=0.3,\n",
    "    )\n",
    "    ax.text(\n",
    "        separator, temporal_feature_importance.max().max(), X_train_mv.columns[index]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "X_train_mv_columns = list(X_train_mv.columns)\n",
    "np.random.shuffle(X_train_mv_columns)\n",
    "\n",
    "X_train_shuffled = X_train_mv[X_train_mv_columns]\n",
    "X_train_shuffled.columns = X_train_mv.columns\n",
    "\n",
    "X_test_shuffled = X_test_mv[X_train_mv_columns]\n",
    "X_test_shuffled.columns = X_test_mv.columns\n",
    "\n",
    "tsf = Pipeline(\n",
    "    [\n",
    "        (\"column_concatenator\", ColumnConcatenator()),\n",
    "        (\"classify\", TimeSeriesForestClassifier(n_estimators=50, random_state=47)),\n",
    "    ]\n",
    ")\n",
    "tsf.fit(X_train_shuffled, y_train_mv)\n",
    "\n",
    "tsf_preds = tsf.predict(X_test_shuffled)\n",
    "print(\"TSF Accuracy: \" + str(metrics.accuracy_score(y_test_mv, tsf_preds)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "temporal_feature_importance = tsf[\"classify\"].feature_importances_\n",
    "separators = range(0, tsf[\"classify\"].series_length, len(X_train_mv.iloc[0, 0]))\n",
    "\n",
    "ax = temporal_feature_importance.plot(figsize=(20, 10))\n",
    "for index, separator in enumerate(separators):\n",
    "    ax.vlines(\n",
    "        separator,\n",
    "        temporal_feature_importance.min().min(),\n",
    "        temporal_feature_importance.max().max(),\n",
    "        color=\"r\",\n",
    "        alpha=0.3,\n",
    "    )\n",
    "    ax.text(\n",
    "        separator, temporal_feature_importance.max().max(), X_train_mv_columns[index]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "tsf = Pipeline(\n",
    "    [\n",
    "        (\"column_concatenator\", ColumnConcatenator()),\n",
    "        (\n",
    "            \"classify\",\n",
    "            TimeSeriesForestClassifier(\n",
    "                n_estimators=50, random_state=47, inner_series_length=100\n",
    "            ),\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "tsf.fit(X_train_mv, y_train_mv)\n",
    "\n",
    "tsf_preds = tsf.predict(X_test_mv)\n",
    "print(\"TSF Accuracy: \" + str(metrics.accuracy_score(y_test_mv, tsf_preds)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "temporal_feature_importance = tsf[\"classify\"].feature_importances_\n",
    "separators = range(0, tsf[\"classify\"].series_length, len(X_train_mv.iloc[0, 0]))\n",
    "\n",
    "ax = temporal_feature_importance.plot(figsize=(20, 10))\n",
    "for index, separator in enumerate(separators):\n",
    "    ax.vlines(\n",
    "        separator,\n",
    "        temporal_feature_importance.min().min(),\n",
    "        temporal_feature_importance.max().max(),\n",
    "        color=\"r\",\n",
    "        alpha=0.3,\n",
    "    )\n",
    "    ax.text(\n",
    "        separator, temporal_feature_importance.max().max(), X_train_mv.columns[index]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "X_train_mv_columns = list(X_train_mv.columns)\n",
    "np.random.shuffle(X_train_mv_columns)\n",
    "\n",
    "X_train_shuffled = X_train_mv[X_train_mv_columns]\n",
    "X_train_shuffled.columns = X_train_mv.columns\n",
    "\n",
    "X_test_shuffled = X_test_mv[X_train_mv_columns]\n",
    "X_test_shuffled.columns = X_test_mv.columns\n",
    "\n",
    "tsf = Pipeline(\n",
    "    [\n",
    "        (\"column_concatenator\", ColumnConcatenator()),\n",
    "        (\n",
    "            \"classify\",\n",
    "            TimeSeriesForestClassifier(\n",
    "                n_estimators=50, random_state=47, inner_series_length=100\n",
    "            ),\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "tsf.fit(X_train_shuffled, y_train_mv)\n",
    "\n",
    "tsf_preds = tsf.predict(X_test_shuffled)\n",
    "print(\"TSF Accuracy: \" + str(metrics.accuracy_score(y_test_mv, tsf_preds)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "temporal_feature_importance = tsf[\"classify\"].feature_importances_\n",
    "separators = range(0, tsf[\"classify\"].series_length, len(X_train_mv.iloc[0, 0]))\n",
    "\n",
    "ax = temporal_feature_importance.plot(figsize=(20, 10))\n",
    "for index, separator in enumerate(separators):\n",
    "    ax.vlines(\n",
    "        separator,\n",
    "        temporal_feature_importance.min().min(),\n",
    "        temporal_feature_importance.max().max(),\n",
    "        color=\"r\",\n",
    "        alpha=0.3,\n",
    "    )\n",
    "    ax.text(\n",
    "        separator, temporal_feature_importance.max().max(), X_train_mv_columns[index]\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Random Interval Spectral Ensemble (RISE)\n",
    "\n",
    "RISE is a tree based interval ensemble aimed at classifying audio data. Unlike TSF, it uses a single interval for each tree, and it uses spectral features rather than summary statistics."
   ]
  },
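  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the contrast with TSF concrete, the sketch below draws a single random interval and computes a power spectrum over it. It is an illustration only, not sktime's implementation: the helper names `power_spectrum` and `random_interval`, the naive discrete Fourier transform and the minimum interval length of 16 are assumptions made for the example, and it covers only one kind of spectral feature."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cmath\n",
    "import math\n",
    "import random\n",
    "\n",
    "\n",
    "def power_spectrum(window):\n",
    "    # Naive discrete Fourier transform; returns the power at each\n",
    "    # non-negative frequency bin (0 .. len(window) // 2).\n",
    "    n = len(window)\n",
    "    spectrum = []\n",
    "    for k in range(n // 2 + 1):\n",
    "        coeff = sum(\n",
    "            value * cmath.exp(-2j * math.pi * k * t / n)\n",
    "            for t, value in enumerate(window)\n",
    "        )\n",
    "        spectrum.append(abs(coeff) ** 2 / n)\n",
    "    return spectrum\n",
    "\n",
    "\n",
    "def random_interval(series_length, rng, min_length=16):\n",
    "    # One interval per tree, unlike the several intervals TSF draws.\n",
    "    length = rng.randrange(min_length, series_length + 1)\n",
    "    start = rng.randrange(0, series_length - length + 1)\n",
    "    return start, start + length\n",
    "\n",
    "\n",
    "# Toy series: a sine wave completing one cycle every 8 time points.\n",
    "series = [math.sin(2 * math.pi * t / 8) for t in range(64)]\n",
    "start, end = random_interval(len(series), random.Random(47))\n",
    "spectral_features = power_spectrum(series[start:end])\n",
    "print(end - start, len(spectral_features))"
   ]
  },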
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "rise = RandomIntervalSpectralEnsemble(n_estimators=50, random_state=47)\n",
    "rise.fit(X_train, y_train)\n",
    "\n",
    "rise_preds = rise.predict(X_test)\n",
    "print(\"RISE Accuracy: \" + str(metrics.accuracy_score(y_test, rise_preds)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Supervised Time Series Forest (STSF)\n",
    "\n",
    "STSF makes a number of adjustments to the original TSF algorithm. A supervised method of selecting intervals replaces random selection. In addition to the original series, intervals are drawn from two further representations: the periodogram and the first order differences. The median, min, max and interquartile range are added to the summary statistics extracted."
   ]
  },
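  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two of these adjustments are easy to show in isolation: the first order differences representation and the additional summary statistics. The sketch below is a plain Python illustration under stated assumptions, not sktime's implementation. The helper names `first_order_differences` and `extra_summary_statistics` and the quartile interpolation method are chosen for the example, and the supervised interval search and the periodogram are not covered."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import statistics\n",
    "\n",
    "\n",
    "def first_order_differences(series):\n",
    "    # Difference between each value and the one before it.\n",
    "    return [b - a for a, b in zip(series, series[1:])]\n",
    "\n",
    "\n",
    "def extra_summary_statistics(window):\n",
    "    # The four statistics added on top of mean, standard deviation and slope.\n",
    "    q1, _, q3 = statistics.quantiles(window, n=4, method='inclusive')\n",
    "    return {\n",
    "        'median': statistics.median(window),\n",
    "        'min': min(window),\n",
    "        'max': max(window),\n",
    "        'iqr': q3 - q1,\n",
    "    }\n",
    "\n",
    "\n",
    "# Toy series whose first order differences are 1, 2, ..., 7.\n",
    "series = [1.0, 2.0, 4.0, 7.0, 11.0, 16.0, 22.0, 29.0]\n",
    "diffs = first_order_differences(series)\n",
    "stats = extra_summary_statistics(diffs[2:7])\n",
    "print(diffs)\n",
    "print(stats)"
   ]
  },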
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "stsf = SupervisedTimeSeriesForest(n_estimators=50, random_state=47)\n",
    "stsf.fit(X_train, y_train)\n",
    "\n",
    "stsf_preds = stsf.predict(X_test)\n",
    "print(\"STSF Accuracy: \" + str(metrics.accuracy_score(y_test, stsf_preds)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Canonical Interval Forest (CIF)\n",
    "\n",
    "CIF extends the TSF algorithm. In addition to the 3 summary statistics used by TSF, CIF makes use of the features from the `Catch22` \\[5\\] transform.\n",
    "To increase the diversity of the ensemble, the TSF and catch22 attributes are randomly subsampled for each tree.\n",
    "\n",
    "### Univariate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-12-19T14:32:06.471294Z",
     "iopub.status.busy": "2020-12-19T14:32:06.467536Z",
     "iopub.status.idle": "2020-12-19T14:32:10.775056Z",
     "shell.execute_reply": "2020-12-19T14:32:10.775964Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "cif = CanonicalIntervalForest(n_estimators=50, att_subsample_size=8, random_state=47)\n",
    "cif.fit(X_train, y_train)\n",
    "\n",
    "cif_preds = cif.predict(X_test)\n",
    "print(\"CIF Accuracy: \" + str(metrics.accuracy_score(y_test, cif_preds)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multivariate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "cif_m = CanonicalIntervalForest(n_estimators=50, att_subsample_size=8, random_state=47)\n",
    "cif_m.fit(X_train_mv, y_train_mv)\n",
    "\n",
    "cif_m_preds = cif_m.predict(X_test_mv)\n",
    "print(\"CIF Accuracy: \" + str(metrics.accuracy_score(y_test_mv, cif_m_preds)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Diverse Representation Canonical Interval Forest (DrCIF)\n",
    "\n",
    "DrCIF makes use of the periodogram and differences representations used by STSF as well as the additional summary statistics in CIF.\n",
    "\n",
    "### Univariate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "drcif = DrCIF(n_estimators=5, att_subsample_size=10, random_state=47)\n",
    "drcif.fit(X_train, y_train)\n",
    "\n",
    "drcif_preds = drcif.predict(X_test)\n",
    "print(\"DrCIF Accuracy: \" + str(metrics.accuracy_score(y_test, drcif_preds)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Multivariate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "drcif_m = DrCIF(n_estimators=5, att_subsample_size=10, random_state=47)\n",
    "drcif_m.fit(X_train_mv, y_train_mv)\n",
    "\n",
    "drcif_m_preds = drcif_m.predict(X_test_mv)\n",
    "print(\"DrCIF Accuracy: \" + str(metrics.accuracy_score(y_test_mv, drcif_m_preds)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
